# -*- coding: utf-8 -*-
import json
import jieba
import pickle
from gensim import corpora, models, similarities
from os.path import exists
from warnings import filterwarnings
filterwarnings('ignore')  # 不打印警告

class CONF:
    """Runtime configuration: file locations for the corpus and cached model."""
    path = '对话语料.json'          # dialogue corpus (JSON, {"data": [...]})
    model_path = '对话模型.pk'      # pickled trained model cache
    
class Model:
    """TF-IDF retrieval chatbot.

    Matches an incoming question against the corpus questions with a
    gensim TF-IDF similarity index and returns the paired answer.
    """

    def __init__(self, question, answer, dictionary, tfidf, index):
        self.dictionary = dictionary    # token -> id mapping built from corpus questions
        self.tfidf = tfidf              # bag-of-words -> TF-IDF transform
        self.index = index              # sparse similarity index over corpus questions
        self.question = question        # corpus questions
        self.answer = answer            # corpus answers (aligned one-to-one with questions)

    @classmethod
    def initialize(cls, config):
        """Build a ready-to-use Model.

        Loads the pickled model when ``config.model_path`` exists; otherwise
        trains from the corpus at ``config.path`` (creating a default corpus
        when that file is missing) and saves the trained model.
        """
        if exists(config.model_path):
            # Reuse the cached model.
            question, answer, dictionary, tfidf, index = cls.__load_model(config.model_path)
        else:
            # Read the corpus, or generate a default one if it is missing.
            data = load_json(config.path) if exists(config.path) else get_data(config.path)
            question, answer, dictionary, tfidf, index = cls.__train_model(data)
            cls.__save_model(config.model_path, question, answer, dictionary, tfidf, index)

        return cls(question, answer, dictionary, tfidf, index)

    @staticmethod
    def __tokenize(text):
        """Segment *text* with jieba into a list of non-whitespace tokens."""
        # join + split mirrors the original space-delimited pipeline and
        # drops any pure-whitespace tokens jieba may emit.
        return " ".join(jieba.cut(text)).split()

    @staticmethod
    def __train_model(data):
        """Train dictionary, TF-IDF model, and similarity index from *data*.

        *data* is a list of ``{'question': ..., 'answer': ...}`` dicts.
        Returns ``(question_list, answer_list, dictionary, tfidf, index)``.
        """
        # Split the corpus into aligned question/answer lists.
        question_list = [line['question'] for line in data]
        answer_list = [line['answer'] for line in data]

        # Tokenize every question.
        texts = [Model.__tokenize(q) for q in question_list]

        # Build the dictionary and bag-of-words corpus (gensim doc2bow).
        dictionary = corpora.Dictionary(texts)
        corpus = [dictionary.doc2bow(text) for text in texts]

        # Fit IDF statistics over the corpus.
        tfidf = models.TfidfModel(corpus)

        # Feature count = vocabulary size; build the sparse similarity index.
        num_features = len(dictionary.token2id)
        index = similarities.SparseMatrixSimilarity(tfidf[corpus], num_features=num_features)
        return question_list, answer_list, dictionary, tfidf, index

    @staticmethod
    def __save_model(model_path, question, answer, dictionary, tfidf, index):
        """Pickle all model components to *model_path*."""
        model = {
            'question': question,
            'answer': answer,
            'dictionary': dictionary,
            'tfidf': tfidf,
            'index': index,
        }
        with open(model_path, "wb") as fh:
            pickle.dump(model, fh)

    @staticmethod
    def __load_model(model_path):
        """Unpickle model components from *model_path*.

        NOTE(review): ``pickle.load`` on an untrusted file can execute
        arbitrary code — only load model files this program wrote itself.
        """
        with open(model_path, "rb") as fh:
            model = pickle.load(fh)
        return (model['question'], model['answer'], model['dictionary'],
                model['tfidf'], model['index'])

    def get_answer(self, question, digalog_id=1):
        """Return ``(answer, digalog_id)`` for the corpus question most
        similar to *question* under the TF-IDF similarity index."""
        # Tokenize the query the same way the corpus was tokenized.
        new_vec = self.dictionary.doc2bow(self.__tokenize(question))
        # Similarity of the query against every corpus question.
        sim = self.index[self.tfidf[new_vec]]
        # argsort ascending: the last position is the best match.
        position = sim.argsort()[-1]
        return self.answer[position], digalog_id
    
    
def load_json(filename, encoding='utf-8'):
    """Read a JSON file and return the value stored under its "data" key.

    :param filename: path to a JSON file shaped like ``{"data": ...}``
    :param encoding: text encoding used to read the file
    :return: the deserialized ``data`` payload
    """
    # (fixed: removed a pointless `filename = filename` self-assignment)
    with open(filename, encoding=encoding) as file_obj:
        rnt = json.load(file_obj)
    return rnt['data']

def save_json(filename, data, encoding='utf-8'):
    """Serialize *data* to *filename* wrapped under a top-level "data" key.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    payload = {"data": data}
    with open(filename, 'w', encoding=encoding) as file_obj:
        json.dump(payload, file_obj, ensure_ascii=False)

def get_data(filename):
    """Build a small default Q&A corpus, persist it to *filename*, return it.

    Questions and answers are paired one-to-one by position.
    """
    questions = ["在吗？", "在干嘛？", "我饿了", "我想看电影。"]
    answers = ["亲，在的。", "在想你呀！", "来我家，做饭给你吃~", "来我家，我家有30寸大电视。"]
    data = [{'question': q, "answer": a} for q, a in zip(questions, answers)]
    save_json(filename, data)
    return data


if __name__ == '__main__':
    # Demo: answer a few sample questions with the trained/loaded model.
    bot = Model.initialize(config=CONF)
    samples = ["在吗？", "在干嘛？", "我饿了", "我肚子饿了", "我肚子好饿", "有好看电影介绍吗？我想看"]
    for query in samples:
        reply, digalog_id = bot.get_answer(query)
        print("\033[031m女神：%s\033[0m" % query)
        print("\033[036m尬聊：%s\033[0m" % reply)